{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import regularizers\n",
    "\n",
    "import tensorflow_docs as tfdocs\n",
    "import tensorflow_docs.modeling\n",
    "import tensorflow_docs.plots\n",
    "\n",
    "from  IPython import display\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import pathlib\n",
    "import shutil\n",
    "import tempfile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Point TensorBoard logging at a sub-folder of a freshly created temporary directory,\n",
    "# and remove anything already at that path (errors are ignored if nothing is there).\n",
    "logdir = pathlib.Path(tempfile.mkdtemp()).joinpath(\"tensorboard_logs\")\n",
    "shutil.rmtree(logdir, ignore_errors=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of feature columns per HIGGS record; each CSV row is 1 label + FEATURES features.\n",
    "FEATURES = 28\n",
    "\n",
    "gz = tf.keras.utils.get_file('HIGGS.csv.gz', 'https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz')\n",
    "# CsvDataset reads the gzip-compressed CSV directly (no separate decompression step)\n",
    "# and returns each record as a list of scalars, one per column.\n",
    "ds = tf.data.experimental.CsvDataset(gz,[float(),]*(FEATURES+1), compression_type=\"GZIP\")\n",
    "\n",
    "def pack_row(*row):\n",
    "  \"\"\"Repack a record's list of column scalars into a (feature_vector, label) pair.\n",
    "\n",
    "  The first column is taken as the label; the remaining columns are stacked\n",
    "  along axis 1 into a single feature tensor.\n",
    "  \"\"\"\n",
    "  label = row[0]\n",
    "  features = tf.stack(row[1:],1)\n",
    "  return features, label\n",
    "\n",
    "# Build a new dataset that takes batches of 10000 examples, applies pack_row to each\n",
    "# batch, and then splits the batches back into individual records.\n",
    "packed_ds = ds.batch(10000).map(pack_row).unbatch()\n",
    "\n",
    "# Inspect one batch of 1000 records: print the first feature vector and plot a\n",
    "# histogram of every feature value in that batch.\n",
    "for features,label in packed_ds.batch(1000).take(1):\n",
    "  print(features[0])\n",
    "  plt.hist(features.numpy().flatten(), bins = 101)\n",
    "\n",
    "N_VALIDATION = int(1e3)\n",
    "N_TRAIN = int(1e4)\n",
    "BUFFER_SIZE = int(1e4)\n",
    "BATCH_SIZE = 500\n",
    "STEPS_PER_EPOCH = N_TRAIN//BATCH_SIZE\n",
    "\n",
    "# The first N_VALIDATION records form the validation set and the next N_TRAIN records\n",
    "# form the training set; both are cached after the first pass over them.\n",
    "validate_ds = packed_ds.take(N_VALIDATION).cache()\n",
    "train_ds = packed_ds.skip(N_VALIDATION).take(N_TRAIN).cache()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6rc1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
